// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck_tile/host.hpp"
#include "rmsnorm2d_fwd.hpp"
#include <cstring>
#include "ck_tile/utility/json_dump.hpp"

// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::int8_t>()
{
    double rtol = 1e-02;
    double atol = 1.0;
    return ck_tile::make_tuple(rtol, atol);
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3328", "m dimension")
        .insert("n", "4096", "n dimension")
        .insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
        .insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
        .insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
        .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
        .insert("e", "1e-5", "epsilon")
        .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
        .insert("save_unquant", "0", "save result before quant")
        .insert("v", "1", "cpu validation or not")
        .insert("kname", "1", "print kernel name or not")
        .insert("prec_i", "fp16", "input precision")
        .insert("prec_o", "auto", "output precision, set auto will be the same as input")
        .insert("prec_sm",
                "auto",
                "output quant scale type, set auto will use fp32. used when fquant=1")
        .insert("prec_sy",
                "auto",
                "output quant scale type, set auto will use fp32. used when fquant=1 or 2")
        .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
        .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
        .insert("warmup", "5", "cold iter")
        .insert("repeat", "20", "hot iter")
        .insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "rmsnorm2d_fwd.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename InDataType,
          typename OutDataType,
          typename SmoothScaleDataType,
          typename YScaleDataType,
          bool SaveRms,
          bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t m              = arg_parser.get_int("m");
    ck_tile::index_t n              = arg_parser.get_int("n");
    float epsilon                   = arg_parser.get_float("e");
    int kname                       = arg_parser.get_int("kname");
    int do_validation               = arg_parser.get_int("v");
    int fused_add                   = arg_parser.get_int("fadd");
    int fused_quant                 = arg_parser.get_int("fquant");
    int warmup                      = arg_parser.get_int("warmup");
    int repeat                      = arg_parser.get_int("repeat");
    int use_model_sensitive_rmsnorm = arg_parser.get_int("s");

    ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
    if(x_stride < 0)
        x_stride = n;
    ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
    if(xr_stride < 0)
        xr_stride = n;
    ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
    if(y_stride < 0)
        y_stride = n;
    ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
    if(yr_stride < 0)
        yr_stride = n;
    assert(x_stride >= n);

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_sm = arg_parser.get_str("prec_sm");
    std::string prec_sy = arg_parser.get_str("prec_sy");
    if(prec_o == "auto")
    {
        prec_o = prec_i;
    }
    if(prec_sm == "auto")
    {
        prec_sm = "fp32";
    }
    if(prec_sy == "auto")
    {
        prec_sy = "fp32";
    }

    if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
    {
        std::cout
            << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
            << std::endl;
        return false;
    }

    if((fused_quant == 0) && SaveUnquant)
    {
        std::cout
            << "save_unquant should be 0 if quant output is not enabled because it is meaningless. "
            << "Output Y is what wanted." << std::endl;
        return false;
    }

    using TypeConfig =
        RmsnormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;

    using XDataType         = typename TypeConfig::XDataType;
    using YDataType         = typename TypeConfig::YDataType;
    using GammaDataType     = typename TypeConfig::GammaDataType;
    using XResidualDataType = XDataType;
    using YResidualDataType = XDataType;

    using InvRmsDataType =
        std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
    using UnquantYDataType =
        std::conditional_t<SaveUnquant, typename TypeConfig::UnquantYDataType, ck_tile::null_type>;

    using ComputeDataType = typename TypeConfig::ComputeDataType;

    // host verify
    ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
    ck_tile::HostTensor<GammaDataType> gamma_host({n});
    ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
    ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});

    ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
    ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});

    ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
    ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
    ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
    ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});

    ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});

    ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {y_stride, 1});
    ck_tile::HostTensor<UnquantYDataType> unquant_y_host_dev({m, n}, {y_stride, 1});
    ck_tile::HostTensor<ck_tile::null_type> unquant_y_null({1});

    ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
    ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
    ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
    ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);

    ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem unquant_y_buf(unquant_y_host_dev.get_element_space_size_in_bytes());

    x_buf.ToDevice(x_host.data());
    gamma_buf.ToDevice(gamma_host.data());
    x_residual_buf.ToDevice(x_residual_host.data());
    sm_scale_buf.ToDevice(sm_scale_host.data());

    auto prec_str = [&]() {
        auto base_str = prec_i;
        if(prec_i != prec_o)
        {
            base_str += "|" + prec_o;
        }
        if(fused_quant == 1)
        {
            base_str += std::string("(") + prec_sy + ")";
        }
        return base_str;
    }();

    if(n > 8192)
    {
        use_model_sensitive_rmsnorm = 0;
    }

    std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
              << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
              << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;

    rmsnorm2d_fwd_traits traits{prec_i,
                                prec_o,
                                prec_sm,
                                prec_sy,
                                SaveRms,
                                SaveUnquant,
                                fused_add,
                                fused_quant,
                                use_model_sensitive_rmsnorm};

    rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
                            fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
                            fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
                            gamma_buf.GetDeviceBuffer(),
                            y_buf.GetDeviceBuffer(),
                            fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
                            fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
                            nullptr, // p_invRms, unsupported yet
                            SaveUnquant ? unquant_y_buf.GetDeviceBuffer() : nullptr,
                            epsilon,
                            m,
                            n,
                            x_stride,   // x row_stride
                            xr_stride,  // x residule row stride
                            y_stride,   // y row stride
                            yr_stride}; // y residule row stride

    float ave_time = rmsnorm2d_fwd(
        traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});

    std::size_t num_byte =
        sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
    num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
    num_byte += SaveUnquant ? sizeof(UnquantYDataType) * m * n : 0;
    num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0;
    num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0;
    num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0;

    float gb_per_sec = num_byte / 1.E6 / ave_time;
    std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;

    bool pass = true;

    if(do_validation)
    {
        // reference
        if(fused_add != 0)
        {
            // fused pre_add/pre_add_store
            // TODO we accumulate directly to x_host for simplcity here...
            std::transform(x_host.mData.cbegin(),
                           x_host.mData.cend(),
                           x_residual_host.mData.cbegin(),
                           x_host.mData.begin(),
                           [](auto x_, auto r_) {
                               auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
                                         ck_tile::type_convert<ComputeDataType>(r_);
                               return ck_tile::type_convert<XDataType>(o_);
                           });
        }

        if(fused_quant != 0)
        {
            auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
                int N_ = acc_.mDesc.get_lengths()[1];
                if(fused_quant == 1)
                {
                    for(int n_ = 0; n_ < N_; n_++)
                    {
                        // input smooth outlier
                        acc_(m_, n_) = acc_(m_, n_) *
                                       ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
                    }
                }
                ComputeDataType absmax = static_cast<ComputeDataType>(0);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    const auto a = ck_tile::abs(acc_(m_, n_));
                    absmax       = a > absmax ? a : absmax;
                }
                // printf("cpu:absmax:%f\n", absmax);
                constexpr ComputeDataType kMaxY =
                    std::is_same<YDataType, ck_tile::fp8_t>::value    ? 240.0
                    : std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
                                                                      : 0.0;
                ComputeDataType y_scale = absmax / kMaxY;
                y_scale_host_ref(m_)    = ck_tile::type_convert<YScaleDataType>(y_scale);
                for(int n_ = 0; n_ < N_; n_++)
                {
                    o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
                }
            };

            auto default_and_dquant_functor = [&](int m_, auto& o_unquant_, auto& o_, auto& acc_) {
                const int N = acc_.mDesc.get_lengths()[1];
                for(int n_ = 0; n_ < N; ++n_)
                {
                    o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
                }

                dquant_functor(m_, o_, acc_);
            };

            if constexpr(SaveUnquant)
            {
                ck_tile::reference_rmsnorm2d_fwd<XDataType,
                                                 GammaDataType,
                                                 ComputeDataType,
                                                 YDataType,
                                                 InvRmsDataType,
                                                 UnquantYDataType>(x_host,
                                                                   gamma_host,
                                                                   y_host_ref,
                                                                   invRms_host_ref,
                                                                   unquant_y_host_ref,
                                                                   epsilon,
                                                                   default_and_dquant_functor,
                                                                   use_model_sensitive_rmsnorm);
            }
            else
            {
                ck_tile::reference_rmsnorm2d_fwd<XDataType,
                                                 GammaDataType,
                                                 ComputeDataType,
                                                 YDataType,
                                                 InvRmsDataType,
                                                 UnquantYDataType>(x_host,
                                                                   gamma_host,
                                                                   y_host_ref,
                                                                   invRms_host_ref,
                                                                   unquant_y_host_ref,
                                                                   epsilon,
                                                                   dquant_functor,
                                                                   use_model_sensitive_rmsnorm);
            }
        }
        else
        {
            assert(SaveUnquant == false);
            ck_tile::reference_rmsnorm2d_fwd<XDataType,
                                             GammaDataType,
                                             ComputeDataType,
                                             YDataType,
                                             InvRmsDataType,
                                             ck_tile::null_type>(
                x_host,
                gamma_host,
                y_host_ref,
                invRms_host_ref,
                unquant_y_null,
                epsilon,
                ck_tile::reference_rmsnorm2d_default_epilogue{},
                use_model_sensitive_rmsnorm);
        }

        y_buf.FromDevice(y_host_dev.data());

        ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
        if(fused_add == 1)
        {
            y_residual_buf.FromDevice(y_residual_host_dev.data());
        }

        if constexpr(SaveUnquant)
        {
            unquant_y_buf.FromDevice(unquant_y_host_dev.data());
        }

        auto [rtol, atol] = get_elimit<YDataType>();
        if(x_stride == n)
        {
            pass = ck_tile::check_err(
                y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol);

            if constexpr(SaveUnquant)
            {
                pass &= ck_tile::check_err(unquant_y_host_dev,
                                           unquant_y_host_ref,
                                           std::string("\n OUT ERROR: Incorrect unquant results!"),
                                           rtol,
                                           atol);
            }

            if(fused_add == 1)
            {
                pass &= ck_tile::check_err(y_residual_host_dev,
                                           x_host,
                                           std::string("\nADD Error: Incorrect results!"),
                                           rtol,
                                           atol);
            }
        }
        else
        {
            for(int i_r = 0; i_r < m; i_r++)
            {
                std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
                                                      y_host_dev.begin() + i_r * y_stride + n);
                std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
                                                      y_host_ref.begin() + i_r * y_stride + n);
                pass &= ck_tile::check_err(y_host_dev_row,
                                           y_host_ref_row,
                                           std::string("\nOUT[") + std::to_string(i_r) +
                                               std::string("] Error: Incorrect results!"),
                                           rtol,
                                           atol);

                if(fused_add == 1)
                {
                    std::vector<YResidualDataType> y_residual_host_dev_row(
                        y_residual_host_dev.begin() + i_r * yr_stride,
                        y_residual_host_dev.begin() + i_r * yr_stride + n);
                    std::vector<YResidualDataType> y_residual_host_ref_row(
                        x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
                    pass &= ck_tile::check_err(y_residual_host_dev_row,
                                               y_residual_host_ref_row,
                                               std::string("\nADD[") + std::to_string(i_r) +
                                                   std::string("] Error: Incorrect results!"),
                                               rtol,
                                               atol);
                }

                if constexpr(SaveUnquant)
                {
                    std::vector<UnquantYDataType> unquant_y_host_dev_row(
                        unquant_y_host_dev.begin() + i_r * y_stride,
                        unquant_y_host_dev.begin() + i_r * y_stride + n);
                    std::vector<UnquantYDataType> unquant_y_host_ref_row(
                        unquant_y_host_ref.begin() + i_r * y_stride,
                        unquant_y_host_ref.begin() + i_r * y_stride + n);
                    pass &=
                        ck_tile::check_err(unquant_y_host_dev_row,
                                           unquant_y_host_ref_row,
                                           std::string("\nOUT[") + std::to_string(i_r) +
                                               std::string("] Error: Incorrect unquant y results!"),
                                           rtol,
                                           atol);
                }
            }
        }

        if(fused_quant == 1)
        {
            y_scale_buf.FromDevice(y_scale_host_dev.data());
            pass &= ck_tile::check_err(y_scale_host_dev,
                                       y_scale_host_ref,
                                       std::string("\nSCALE Error: Incorrect results!"),
                                       rtol,
                                       atol);
        }

        std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
    }

    if(arg_parser.get_int("json") == 1)
    {
        dump_rmsnorm2d_fwd_json(arg_parser.get_str("jsonfile"),
                                prec_str,
                                m,
                                n,
                                x_stride,
                                xr_stride,
                                y_stride,
                                yr_stride,
                                use_model_sensitive_rmsnorm,
                                ave_time,
                                0,
                                gb_per_sec,
                                pass);
    }

    return pass;
}

bool is_quant_data_type(const std::string& prec) { return (prec == "int8") || (prec == "fp8"); }

int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec_i  = arg_parser.get_str("prec_i");
    std::string prec_o  = arg_parser.get_str("prec_o");
    std::string prec_sm = arg_parser.get_str("prec_sm");
    std::string prec_sy = arg_parser.get_str("prec_sy");
    if(prec_o == "auto")
    {
        prec_o = prec_i;
    }
    if(prec_sm == "auto")
    {
        prec_sm = "fp32";
    }
    if(prec_sy == "auto")
    {
        prec_sy = "fp32";
    }

    int save_rms    = arg_parser.get_int("save_rms");
    int fused_quant = arg_parser.get_int("fquant");
    int save_unquant =
        arg_parser.get_int("save_unquant") && is_quant_data_type(prec_o) && (fused_quant != 0);

    if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms)
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, true, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms)
    {
        return run<ck_tile::half_t, ck_tile::half_t, float, float, false, false>(arg_parser) ? 0
                                                                                             : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
            save_rms)
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms)
    {
        return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false, false>(arg_parser) ? 0
                                                                                             : -2;
    }

    // dynamic quant case, only in inference
    else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && !save_unquant)
    {
        return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && !save_unquant)
    {
        return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && !save_unquant)
    {
        return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && !save_unquant)
    {
        return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, false>(arg_parser) ? 0
                                                                                            : -2;
    }
    else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && save_unquant)
    {
        return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && save_unquant)
    {
        return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && save_unquant)
    {
        return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, true>(arg_parser) ? 0 : -2;
    }
    else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
            !save_rms && save_unquant)
    {
        return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, true>(arg_parser) ? 0 : -2;
    }

    return -3;
}
